
#include "appinc.h"
#include "mqtt.h"


/*
 * Return how many milliseconds remain before timer expires (0 once expired).
 *
 * The elapsed time is deliberately narrowed to 32 bits: unsigned modular
 * subtraction then stays correct even if timer->start holds a wrapped
 * (truncated) uptime stamp, as long as the measured span is < 2^32 ms.
 */
static uint32_t mqtt_timer_left(m_timer *timer)
{
	uint64_t now = libsys_get_uptime_ms();
	uint32_t elapsed = now - timer->start;

	return (elapsed < timer->period) ? timer->period - elapsed : 0;
}


/* Restart timer so that its whole period is available again, counted from now. */
static void mqtt_timer_reset(m_timer *timer)
{
	timer->start = libsys_get_uptime_ms();
}

/* Arm timer to run out ms milliseconds from now. */
static void mqtt_timer_init(m_timer *timer, uint32_t ms)
{
	timer->period = ms;
	mqtt_timer_reset(timer);
}


/*
 * Allocate the next MQTT packet identifier and store it in context->pid.
 *
 * MQTT packet identifiers are 16-bit and must never be 0, so the sequence
 * cycles 1, 2, ..., 65535, 1, ...  The modulo keeps the value inside that
 * range whatever integer type context->pid has.  A plain "increment, skip 0"
 * would hand out 65536 when pid is wider than 16 bits, and that value
 * becomes the illegal id 0 once narrowed to the unsigned short id used by
 * the packet serializers.
 *
 * Returns the newly allocated identifier.
 */
static int getNextPacketId(mqtt_context_t *context)
{
	context->pid = (context->pid % 65535) + 1;
	return context->pid;
}


/*
 * Mark the connection as broken (context->connected = -1).
 *
 * reason is recorded only while the link was still considered up
 * (connected > 0). After this call connected is -1, so a later secondary
 * error does not overwrite the first cause. snprintf() always
 * NUL-terminates, unlike strncpy(), even when reason has to be truncated.
 */
static void on_mqtt_disconnect(mqtt_context_t *context, const char *reason)
{
	if (reason && context->connected > 0)
		snprintf(context->reason, sizeof(context->reason), "%s", reason);
	context->connected = -1;
}


/*
 * Run a TLS 1.2 client handshake over the already-connected TCP socket in
 * context->socket.
 *
 * On success, context->ssl_ctx and context->ssl are set and 0 is returned.
 * On failure, every TLS object created here is released and both pointers
 * are reset to NULL. The socket is attached with BIO_NOCLOSE, so it stays
 * open and remains the caller's to close. Returns -EFAULT on failure, and
 * always when built without WITH_TLS.
 *
 * NOTE(review): SSL_VERIFY_NONE disables server certificate verification,
 * so the broker is not authenticated. Kept to preserve current behavior.
 * NOTE(review): TLSv1_2_client_method() is deprecated since OpenSSL 1.1.0
 * in favour of TLS_client_method(). Kept for compatibility with older builds.
 */
static int mqtt_ssl_connect(mqtt_context_t *context, mqtt_param_t *param)
{
#ifdef WITH_TLS
	BIO *bio;

	(void)param;
	context->ssl = NULL;
	context->ssl_ctx = SSL_CTX_new(TLSv1_2_client_method());
	if (context->ssl_ctx == NULL)
		goto fail;
	SSL_CTX_set_verify(context->ssl_ctx, SSL_VERIFY_NONE, NULL);

	context->ssl = SSL_new(context->ssl_ctx);
	if (context->ssl == NULL)
		goto fail;

	bio = BIO_new_socket(context->socket, BIO_NOCLOSE);
	if (bio == NULL)
		goto fail;
	/* the SSL object now owns bio; SSL_free() releases it */
	SSL_set_bio(context->ssl, bio, bio);

	if (SSL_connect(context->ssl) == 1)
		return 0;

fail:
	if (context->ssl)
		SSL_free(context->ssl);
	if (context->ssl_ctx)
		SSL_CTX_free(context->ssl_ctx);
	context->ssl = NULL;
	context->ssl_ctx = NULL;
#else
	(void)context;
	(void)param;
#endif

	return -EFAULT;
}


/*
 * Open the plain TCP connection to param->hostname:param->port and store
 * the socket in context->socket.
 *
 * Name resolution and the connect timeout are delegated to
 * libsocket_connect_tcp(). The value 5000 is presumably milliseconds; confirm
 * against libsocket. TCP_NODELAY is requested on a best-effort basis so
 * small MQTT packets are not held back by Nagle's algorithm.
 *
 * Returns 0 on success. Otherwise returns the negative value from
 * libsocket_connect_tcp(), and context->socket is left untouched.
 */
static int mqtt_socket_connect(mqtt_context_t *context, mqtt_param_t *param)
{
	/* large enough for any int: optional sign, 10 digits, NUL */
	char port_str[12];
	int sockfd, opval = 1;

	snprintf(port_str, sizeof(port_str), "%d", param->port);
	sockfd = libsocket_connect_tcp(param->hostname, port_str, 5000);
	if (sockfd < 0)
		return sockfd;

	/* best effort: a failure here only costs latency */
	setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, &opval, sizeof(opval));
	context->socket = sockfd;

	return 0;
}


/*
 * Read exactly len bytes into buffer, blocking until they have all arrived.
 *
 * Returns len on success. On failure the connection is marked broken via
 * on_mqtt_disconnect() and a negative value is returned:
 *   - the negative recv()/SSL_read() result, or
 *   - -ENETRESET when a read returned 0 before any byte arrived, or after
 *     more than 10 one-second retries.
 * A plain recv() interrupted by a signal (EINTR) is retried instead of
 * being treated as a disconnect.
 */
static int mqtt_read(mqtt_context_t *context, uint8_t *buffer, int len)
{
	int ret, recvlen = 0, retry = 0;
	char reason[32];

	while (recvlen < len) {
#ifdef WITH_TLS
		if (context->ssl)
			ret = SSL_read(context->ssl, buffer + recvlen, len - recvlen);
		else
#endif
		{
			ret = recv(context->socket, buffer + recvlen, len - recvlen, 0);
			/* a signal interrupted the wait; the link itself is fine */
			if (ret < 0 && errno == EINTR)
				continue;
		}
		if (ret < 0 || (ret == 0 && (recvlen == 0 || retry > 10))) {
			LOG_W("%s: %s recv fail, ret=%d, recvlen=%d", context->name, __func__, ret, recvlen);
			snprintf(reason, sizeof(reason), "read(%d)", ret);
			on_mqtt_disconnect(context, reason);
			if (ret == 0)
				ret = -ENETRESET;
			return ret;
		} else if (ret == 0) {
			/*
			 * NOTE(review): on a stream socket recv() == 0 means the peer
			 * closed the connection, so these retries are unlikely to ever
			 * yield data. Consider failing immediately.
			 */
			sleep(1);
			retry++;
		}
		recvlen += ret;
	}

	return recvlen;
}


/*
 * Write up to len bytes from buffer with a single send.
 *
 * Returns the number of bytes actually written, which may be less than
 * len; sendPacket() loops for the rest. On failure it returns a negative
 * value and the connection is marked broken via on_mqtt_disconnect().
 *
 * For plain sockets, send() with MSG_NOSIGNAL is used where the platform
 * defines it. Writing to a peer-closed socket then fails with EPIPE instead
 * of raising SIGPIPE, whose default action kills the process. A send
 * interrupted by a signal (EINTR) is retried.
 */
static int mqtt_write(mqtt_context_t *context, uint8_t *buffer, int len)
{
	int ret;
	char reason[32];

#ifdef WITH_TLS
	if (context->ssl)
		ret = SSL_write(context->ssl, buffer, len);
	else
#endif
	{
		int flags = 0;
#ifdef MSG_NOSIGNAL
		flags = MSG_NOSIGNAL;
#endif
		do {
			ret = send(context->socket, buffer, len, flags);
		} while (ret < 0 && errno == EINTR);
	}
	if (ret < 0) {
		snprintf(reason, sizeof(reason), "write(%d)", ret);
		on_mqtt_disconnect(context, reason);
	}

	return ret;
}


/*
 * Send a complete serialized packet, looping over partial writes.
 *
 * Returns 0 once all length bytes are written; send_timer, which drives
 * keepalive, is restarted in that case. Otherwise returns -EFAULT.
 * A negative write result has already marked the connection broken inside
 * mqtt_write().
 */
static int sendPacket(mqtt_context_t *context, uint8_t *sendbuf, int length)
{
	int ret, sent = 0;

	while (sent < length) {
		/* pass only the not-yet-sent tail, never more than what is left */
		ret = mqtt_write(context, sendbuf + sent, length - sent);
		/* a write that makes no progress would otherwise spin forever */
		if (ret <= 0)
			break;
		sent += ret;
	}
	ret = (sent == length) ? 0 : -EFAULT;

	if (ret == 0)
		mqtt_timer_reset(&context->send_timer);

	LOG_D("%s: sendPacket: sent=%d", context->name, sent);
	return ret;
}


/*
 * Wait for and read one complete MQTT packet into context->readbuf.
 *
 * timer bounds the wait for the first byte; NULL means wait indefinitely.
 *
 * Returns:
 *   > 0  total packet length (fixed header + remaining length). The packet
 *        starts at context->readbuf[0]. For packets with a body, at least
 *        one spare byte is kept after it so the caller can NUL-terminate a
 *        PUBLISH payload in place.
 *     0  nothing was returned: the wait timed out, select() was interrupted
 *        by a signal, or an oversized packet was read and discarded.
 *   < 0  error. select()/read failures also mark the connection broken;
 *        -ENOTCONN is returned when there is no socket at all.
 */
static int readPacket(mqtt_context_t *context, m_timer *timer)
{
	int ret, len, rem_len = 0, multiplier = 1;
	int need_wait = 1;

	/* FD_SET() with a negative descriptor is undefined */
	if (context->socket < 0)
		return -ENOTCONN;

#ifdef WITH_TLS
	/*
	 * An earlier SSL_read() may have decrypted more bytes than it returned.
	 * They sit inside the SSL object, invisible to select() on the raw
	 * socket, so do not wait when data is already pending.
	 */
	if (context->ssl && SSL_pending(context->ssl) > 0)
		need_wait = 0;
#endif
	if (need_wait) {
		fd_set fdset;
		struct timeval tv, *tvp = NULL;

		FD_ZERO(&fdset);
		FD_SET(context->socket, &fdset);
		if (timer) {
			uint32_t timeout = mqtt_timer_left(timer);
			tv.tv_sec = timeout / 1000;
			tv.tv_usec = (timeout % 1000) * 1000;
			tvp = &tv;
		}
		ret = select(context->socket + 1, &fdset, NULL, NULL, tvp);
		if (ret < 0) {
			/* interrupted by a signal: report "nothing read" so the caller retries */
			if (errno == EINTR)
				return 0;
			LOG_W("%s: %s select ret:%d", context->name, __func__, ret);
			on_mqtt_disconnect(context, "select");
			return ret;
		}
		if (ret == 0)
			return 0;
	}

	/* 1. read the header byte. This has the packet type in it */
	len = mqtt_read(context, context->readbuf, 1);
	if (len != 1)
		return -EFAULT;

	/* 2. read the remaining length: 1-4 bytes, 7 bits each, least significant first */
	while (1) {
		if (mqtt_read(context, context->readbuf + len, 1) != 1)
			return -EFAULT;
		rem_len += (context->readbuf[len] & 127) * multiplier;
		if (!(context->readbuf[len++] & 128))
			break;
		if (len > 4)
			return -EFAULT;
		multiplier *= 128;
	}

	if (rem_len > 0) {
		/*
		 * 3. make room for the body plus one spare byte. The PUBLISH
		 * handler writes a NUL right after the payload, i.e. at
		 * readbuf[len + rem_len], which must stay inside the buffer.
		 */
		int needed = len + rem_len + 1;

		if (needed > context->readbuf_size) {
			int new_size = (len + rem_len) * 2;
			if (new_size > RECIVER_MAX_SIZE)
				new_size = RECIVER_MAX_SIZE;
			if (new_size > context->readbuf_size) {
				uint8_t *buf = realloc(context->readbuf, new_size);
				if (buf) {
					context->readbuf = buf;
					context->readbuf_size = new_size;
				}
			}
		}
		if (needed > context->readbuf_size) {
			LOG_W("%s: recive package too large", context->name);
			/* drain and drop the body so the byte stream stays in sync */
			while (rem_len > 0) {
				int chunk = rem_len > 1024 ? 1024 : rem_len;

				ret = mqtt_read(context, context->readbuf, chunk);
				if (ret < 0)
					return ret;
				rem_len -= ret;
			}
			return 0;
		}
		ret = mqtt_read(context, context->readbuf + len, rem_len);
		if (ret != rem_len)
			return -EFAULT;
	}
	len += rem_len;
	LOG_D("%s: readPacket: read:%d", context->name, len);

	return len;
}


/*
 * Keepalive housekeeping. mqtt_cycle() calls this when readPacket()
 * returned nothing.
 *
 * Does nothing unless connected with a non-zero keepalive period and at
 * least one of send_timer / recive_timer has run out. At that point:
 *   - if a PINGREQ is still unanswered, the connection is declared dead;
 *   - otherwise a PINGREQ is sent and recive_timer restarted, giving the
 *     PINGRESP a full period to arrive.
 *
 * The PINGREQ is written while holding context->mutex, the lock that
 * hmqtt_publish() and hmqtt_disconnect() hold around their sends, so it
 * cannot interleave with a concurrent publish on the wire.
 */
static void keepalive(mqtt_context_t *context)
{
	int ret = -EFAULT, len;
	uint8_t sendbuf[128];

	if (context->connected <= 0 || context->keep_ms == 0)
		return;

	if (mqtt_timer_left(&context->recive_timer) && mqtt_timer_left(&context->send_timer))
		return;

	if (context->pingreq) {
		LOG_W("%s: disconnect by timer", context->name);
		on_mqtt_disconnect(context, "keepalive");
		return;
	}

	LOG_D("%s: keepalive", context->name);
	len = MQTTSerialize_pingreq(sendbuf, sizeof(sendbuf));
	if (len > 0) {
		pthread_mutex_lock(&context->mutex);
		ret = sendPacket(context, sendbuf, len);
		pthread_mutex_unlock(&context->mutex);
	}
	if (ret == 0)
		context->pingreq = 1;
	mqtt_timer_reset(&context->recive_timer);
}


/*
 * Serialize and send an acknowledgement packet of the given type for
 * packet id pid. It is sent while holding context->mutex, the lock that
 * hmqtt_publish() and hmqtt_disconnect() hold around their sends, so it
 * cannot interleave with a concurrent publish. Never call it while already
 * holding that mutex.
 * Returns 0 on success, a negative value otherwise.
 */
static int mqtt_send_ack(mqtt_context_t *context, unsigned char type, unsigned short pid)
{
	uint8_t sendbuf[128];
	int ret = -EFAULT;
	int len = MQTTSerialize_ack(sendbuf, sizeof(sendbuf), type, 0, pid);

	if (len > 0) {
		pthread_mutex_lock(&context->mutex);
		ret = sendPacket(context, sendbuf, len);
		pthread_mutex_unlock(&context->mutex);
	}
	return ret;
}

/*
 * Process at most one incoming packet.
 *
 * Returns the type of the packet that was received (CONNACK, PUBLISH, ...),
 * 0 when nothing was read (keepalive housekeeping runs in that case), or a
 * negative error code.
 */
static int mqtt_cycle(mqtt_context_t *context, m_timer *timer)
{
	int pktlen, ret, qos;
	MQTTHeader header = {0};

	pktlen = readPacket(context, timer);
	if (pktlen < 0) {
		LOG_I("read packet failed");
		return pktlen;
	}
	if (pktlen == 0) {
		keepalive(context);
		return 0;
	}

	header.byte = context->readbuf[0];
	switch (header.bits.type) {
		case CONNACK: {
			unsigned char connack_ret = 255;
			char sessionPresent = 0;
			LOG_D("%s: RX CONNACK", context->name);
			if (MQTTDeserialize_connack((unsigned char*)&sessionPresent, &connack_ret, context->readbuf, pktlen) != 1)
				return -EFAULT;
			if (connack_ret == 0) {
				context->connected = 1;
				/* the broker kept our previous session, including its subscriptions */
				if (sessionPresent && context->not_clean)
					context->subscribed = 1;
			}
			break;
		}

		case PUBACK:
			LOG_D("%s: RX PUBACK", context->name);
			break;

		case SUBACK: {
			int count = 0, grantedQoS = -1;
			unsigned short mypacketid;
			LOG_D("%s: RX SUBACK", context->name);
			if (MQTTDeserialize_suback(&mypacketid, 1, &count, &grantedQoS, context->readbuf, pktlen) != 1)
				return -EFAULT;
			/* anything outside 0..2 (e.g. the 0x80 failure code) means refused */
			if (grantedQoS >= 0 && grantedQoS <= 2)
				context->subscribed = 1;
			break;
		}

		case PUBLISH: {
			mqtt_message_t msg;
			long payload_end;
			LOG_D("%s: RX PUBLISH", context->name);
			if (MQTTDeserialize_publish(&msg.dup, &qos, &msg.retained, &msg.pid, &msg.topic,
				&msg.payload, &msg.payloadlen, context->readbuf, pktlen) != 1)
				return -EFAULT;
			/* the NUL written after the payload must stay inside readbuf */
			payload_end = (long)((const uint8_t *)msg.payload - context->readbuf) + msg.payloadlen;
			if (payload_end >= context->readbuf_size)
				return -EFAULT;
			msg.qos = qos;
			msg.payload[msg.payloadlen] = '\0';

			if (context->on_message != NULL)
				context->on_message(&msg, context->msg_arg);
			/* the callback may have disconnected; only ack on a live link */
			if (msg.qos != QOS0 && context->connected > 0) {
				ret = -EFAULT;
				if (msg.qos == QOS1)
					ret = mqtt_send_ack(context, PUBACK, msg.pid);
				else if (msg.qos == QOS2)
					ret = mqtt_send_ack(context, PUBREC, msg.pid);
				if (ret != 0)
					return ret;
			}
			break;
		}

		case PUBREC: {
			/* outbound QoS 2, step 2: answer the broker's PUBREC with PUBREL */
			unsigned short mypacketid;
			unsigned char dup, type;
			LOG_D("%s: RX PUBREC", context->name);
			if (MQTTDeserialize_ack(&type, &dup, &mypacketid, context->readbuf, pktlen) != 1)
				return -EFAULT;
			ret = mqtt_send_ack(context, PUBREL, mypacketid);
			if (ret != 0)
				return ret;
			break;
		}

		case PUBREL: {
			/*
			 * Inbound QoS 2, step 3: after our PUBREC (see PUBLISH above)
			 * the broker releases the message and expects PUBCOMP.
			 * Without it the QoS 2 exchange never completes on its side.
			 */
			unsigned short mypacketid;
			unsigned char dup, type;
			LOG_D("%s: RX PUBREL", context->name);
			if (MQTTDeserialize_ack(&type, &dup, &mypacketid, context->readbuf, pktlen) != 1)
				return -EFAULT;
			ret = mqtt_send_ack(context, PUBCOMP, mypacketid);
			if (ret != 0)
				return ret;
			break;
		}

		case PUBCOMP:
			LOG_D("%s: RX PUBCOMP", context->name);
			break;

		case PINGRESP:
			LOG_D("%s: RX PINGRESP", context->name);
			context->pingreq = 0;
			break;

		case DISCONNECT:
			LOG_I("%s: RX DISCONNECT", context->name);
			on_mqtt_disconnect(context, "remote");
			break;

		default:
			LOG_D("%s: RX unhandled packet type %d", context->name, (int)header.bits.type);
			break;
	}
	/* any packet from the broker proves the link is alive */
	mqtt_timer_reset(&context->recive_timer);

	return header.bits.type;
}


/*
 * Pump mqtt_cycle() until a packet of type packet_type arrives, the link
 * breaks, or MQTT_TIMEOUT_MS elapses.
 *
 * Returns 0 once packet_type was seen. Otherwise returns the last
 * mqtt_cycle() result (negative when the connection broke), or -1 if the
 * deadline passed before any cycle ran.
 */
static int wait_packet(mqtt_context_t *context, int packet_type)
{
	m_timer deadline;
	int rc = -1;

	mqtt_timer_init(&deadline, MQTT_TIMEOUT_MS);

	while (mqtt_timer_left(&deadline)) {
		rc = mqtt_cycle(context, &deadline);
		/* a read error only ends the wait if it also broke the link */
		if (rc < 0 && context->connected < 0)
			break;
		if (rc == packet_type)
			return 0;
	}

	return rc;
}



/*
 * Connect context to the broker described by param.
 *
 * Opens the TCP connection, plus TLS when param->ssl_enable is set. It then
 * sends CONNECT with the client id, optional username/password, keepalive
 * and clean-session flag from param, and waits up to MQTT_TIMEOUT_MS for an
 * accepting CONNACK.
 *
 * Returns:
 *   0          connected (context->connected == 1)
 *   -EINVAL    context or param is NULL
 *   -ENOTCONN  the TCP or TLS connection could not be established
 *   -EFAULT    CONNECT could not be built or sent, or no accepting CONNACK
 *              arrived. The socket and any TLS objects are then released
 *              and context->connected is reset to 0.
 */
int hmqtt_connect(mqtt_context_t *context, mqtt_param_t *param)
{
	MQTTPacket_connectData options = MQTTPacket_connectData_initializer;
	int ret, len, bufsize;
	uint8_t *sendbuf;

	if (context == NULL || param == NULL)
		return -EINVAL;

	LOG_I("%s: start mqtt connect: %s:%d, clientid:%s", context->name, param->hostname, param->port, param->clientid);

	ret = mqtt_socket_connect(context, param);
	if (ret < 0) {
		LOG_I("%s: mqtt socket connect fail", context->name);
		return -ENOTCONN;
	}
	if (param->ssl_enable) {
		ret = mqtt_ssl_connect(context, param);
		if (ret < 0) {
			LOG_I("%s: mqtt ssl connect fail", context->name);
			close(context->socket);
			context->socket = -1;
			return -ENOTCONN;
		}
	}
	LOG_I("%s: mqtt connect success", context->name);

	options.willFlag = 0;
	options.MQTTVersion = 4;	/* MQTT 3.1.1 */
	options.clientID.cstring = param->clientid;
	if (strlen(param->username))
		options.username.cstring = param->username;
	if (strlen(param->password))
		options.password.cstring = param->password;
	options.keepAliveInterval = param->keepalive;
	options.cleansession = param->cleansession;
	context->keep_ms = param->keepalive * 1000;
	context->pingreq = 0;
	/* only a CONNACK received during this attempt may report success below */
	context->connected = 0;
	if (context->keep_ms) {
		mqtt_timer_init(&context->send_timer, context->keep_ms);
		mqtt_timer_init(&context->recive_timer, context->keep_ms);
	}

	/*
	 * CONNECT = fixed header (<= 5 bytes) + variable header (10 bytes) +
	 * three length-prefixed strings (2 + len each, no will message).
	 * 128 bytes on top of the string lengths therefore always suffice, so
	 * long credentials (e.g. access tokens) still fit.
	 */
	bufsize = 128 + (int)(strlen(param->clientid) + strlen(param->username) + strlen(param->password));
	ret = -EFAULT;
	sendbuf = malloc(bufsize);
	if (sendbuf != NULL) {
		len = MQTTSerialize_connect(sendbuf, bufsize, &options);
		if (len > 0)
			ret = sendPacket(context, sendbuf, len);
		free(sendbuf);
	}
	if (ret == 0)
		wait_packet(context, CONNACK);

	if (context->connected == 1) {
		context->not_clean = !param->cleansession;
		LOG_I("%s: connect to %s:%d success", context->name, param->hostname, param->port);
		return 0;
	}

	/* failure: release everything set up above so a retry starts clean */
#ifdef WITH_TLS
	if (context->ssl)
		SSL_free(context->ssl);
	if (context->ssl_ctx)
		SSL_CTX_free(context->ssl_ctx);
	context->ssl = NULL;
	context->ssl_ctx = NULL;
#endif
	if (context->socket >= 0)
		close(context->socket);
	context->socket = -1;
	context->connected = 0;
	LOG_I("%s: connect to %s:%d fail", context->name, param->hostname, param->port);
	return -EFAULT;
}


/*
 * Subscribe to topic with the requested qos and wait up to MQTT_TIMEOUT_MS
 * for the broker's SUBACK.
 *
 * Returns 0 when this SUBACK granted the subscription. Returns -EINVAL for
 * NULL arguments or when not connected, and -EFAULT otherwise.
 *
 * context->subscribed is cleared only while this request is in flight, so
 * the result reflects this SUBACK rather than an earlier subscription. If
 * it was set before, it is set again afterwards, keeping its meaning of
 * "some subscription is active".
 * The SUBSCRIBE packet and its packet id are produced while holding
 * context->mutex, the lock hmqtt_publish() uses for the same purpose.
 */
int hmqtt_subscribe(mqtt_context_t *context, const char *topic, int qos)
{
	MQTTString sub_topic = MQTTString_initializer;
	int rc = -EFAULT, len, was_subscribed, granted;
	uint8_t sendbuf[MQTT_TOPIC_LEN + 32];

	if (context == NULL || topic == NULL || context->connected <= 0)
		return -EINVAL;

	LOG_D("%s: subscribe %s ...", context->name, topic);
	sub_topic.cstring = (char *)topic;

	was_subscribed = context->subscribed;
	context->subscribed = 0;

	pthread_mutex_lock(&context->mutex);
	len = MQTTSerialize_subscribe(sendbuf, sizeof(sendbuf), 0, getNextPacketId(context), 1, &sub_topic, &qos);
	if (len > 0)
		rc = sendPacket(context, sendbuf, len);
	pthread_mutex_unlock(&context->mutex);

	/* the mutex must not be held here: mqtt_cycle() may need it to send acks */
	if (rc == 0)
		wait_packet(context, SUBACK);

	granted = context->subscribed;
	if (was_subscribed)
		context->subscribed = 1;

	if (granted) {
		LOG_D("%s: subscribe %s success", context->name, topic);
		return 0;
	}
	LOG_D("%s: subscribe %s fail", context->name, topic);
	return -EFAULT;
}


/*
 * Publish message on the current connection.
 *
 * For QoS 1/2 a fresh packet id is allocated and written back into
 * message->pid. The packet is serialized into a heap buffer sized from the
 * payload length plus MQTT_TOPIC_LEN + 32. Allocating the pid and sending
 * happen while holding context->mutex, so concurrent publishers do not
 * interleave on the wire.
 *
 * Returns 0 when the whole packet was written, -EINVAL for NULL arguments
 * or when not connected, -ENOMEM when the buffer cannot be allocated, and
 * -EFAULT when serialization or sending fails.
 */
int hmqtt_publish(mqtt_context_t *context, mqtt_message_t *message)
{
	int rc = -EFAULT, len;
	int sendbuf_size;
	uint8_t *sendbuf;

	if (context == NULL || message == NULL || context->connected <= 0)
		return -EINVAL;

	LOG_D("%s: publish %s, payloadlen=%d", context->name, message->topic.cstring, message->payloadlen);
	/*
	 * JSON-looking payloads are logged. The payload is not necessarily
	 * NUL-terminated, so the print is bounded by payloadlen, and an empty
	 * payload is never indexed (payload[payloadlen - 1] would be payload[-1]).
	 */
	if (message->payload != NULL && message->payloadlen > 0 &&
	    message->payload[0] == '{' && message->payload[message->payloadlen - 1] == '}')
		LOG_D("%s: payload:%.*s", context->name, message->payloadlen, message->payload);

	/* allocate before taking the lock; nothing shared is touched yet */
	sendbuf_size = message->payloadlen + MQTT_TOPIC_LEN + 32;
	sendbuf = malloc(sendbuf_size);
	if (sendbuf == NULL)
		return -ENOMEM;

	pthread_mutex_lock(&context->mutex);
	if (message->qos == QOS1 || message->qos == QOS2)
		message->pid = getNextPacketId(context);
	len = MQTTSerialize_publish(sendbuf, sendbuf_size,
				message->dup, message->qos, message->retained, message->pid,
				message->topic, message->payload, message->payloadlen);
	if (len > 0)
		rc = sendPacket(context, sendbuf, len);
	pthread_mutex_unlock(&context->mutex);
	free(sendbuf);

	/* TODO: QoS 1 and QoS 2 publishes are not retransmitted yet */

	if (rc == 0)
		LOG_D("%s: publish success", context->name);
	else
		LOG_D("%s: publish fail", context->name);

	return rc;
}

/*
 * Convenience wrapper around hmqtt_publish(): send payloadlen bytes of
 * payload to topic. Every other message field is zero (so qos, retained,
 * dup and pid are all 0).
 * Returns whatever hmqtt_publish() returns.
 */
int hmqtt_publish_qos0(mqtt_context_t *context, const char *topic, void *payload, int payloadlen)
{
	mqtt_message_t msg = {0};

	msg.topic.cstring = (char *)topic;
	msg.payload = payload;
	msg.payloadlen = payloadlen;

	return hmqtt_publish(context, &msg);
}


/*
 * Close the connection and release its transport.
 *
 * reason is recorded only if the link was still up (connected > 0), and a
 * DISCONNECT packet is sent first in that case. Afterwards the TLS objects
 * are freed, the socket is closed and connected/subscribed are reset to 0.
 * Everything after recording reason runs while holding context->mutex.
 *
 * Returns 0, or -EINVAL for a NULL context or one that is already fully
 * disconnected (connected == 0).
 */
int hmqtt_disconnect(mqtt_context_t *context, const char *reason)
{
	int len;
	uint8_t sendbuf[128];

	if (context == NULL || context->connected == 0)
		return -EINVAL;

	/* snprintf always NUL-terminates, unlike strncpy */
	if (reason && context->connected > 0)
		snprintf(context->reason, sizeof(context->reason), "%s", reason);

	pthread_mutex_lock(&context->mutex);
	/* descriptor 0 is valid too, matching the close() test below */
	if (context->socket >= 0) {
		LOG_I("%s: disonnect", context->name);
		if (context->connected > 0) {
			len = MQTTSerialize_disconnect(sendbuf, sizeof(sendbuf));
			if (len > 0)
				sendPacket(context, sendbuf, len);
		}
	}

	context->connected = 0;
	context->subscribed = 0;
#ifdef WITH_TLS
	if (context->ssl)
		SSL_free(context->ssl);
	if (context->ssl_ctx)
		SSL_CTX_free(context->ssl_ctx);
	context->ssl = NULL;
	context->ssl_ctx = NULL;
#endif
	if (context->socket >= 0)
		close(context->socket);

	context->socket = -1;
	pthread_mutex_unlock(&context->mutex);

	return 0;
}


/*
 * Run the receive loop for up to timeout milliseconds (0 = no limit).
 *
 * Each pass waits no longer than the nearest of the keepalive timers (when
 * keepalive is enabled) and the overall deadline. It stops on an error,
 * when the connection is no longer up, or when context->wakeup is set.
 *
 * Returns the last mqtt_cycle() result, 0 if the deadline passed first,
 * or -EINVAL for a NULL context.
 */
int hmqtt_loop(mqtt_context_t *context, uint32_t timeout)
{
	const uint32_t forever = (uint32_t)-1;
	m_timer deadline, slice;
	int rc = 0;

	if (!context)
		return -EINVAL;

	if (!timeout)
		timeout = forever;
	mqtt_timer_init(&deadline, timeout);

	context->wakeup = false;
	for (;;) {
		uint32_t wait_ms = forever;
		uint32_t left;
		m_timer *slice_ptr = NULL;

		if (context->keep_ms) {
			left = mqtt_timer_left(&context->send_timer);
			if (left < wait_ms)
				wait_ms = left;
			left = mqtt_timer_left(&context->recive_timer);
			if (left < wait_ms)
				wait_ms = left;
		}
		if (timeout != forever) {
			left = mqtt_timer_left(&deadline);
			if (!left)
				break;
			if (left < wait_ms)
				wait_ms = left;
		}
		/* "forever" here means: block in select() without a timeout */
		if (wait_ms != forever) {
			mqtt_timer_init(&slice, wait_ms);
			slice_ptr = &slice;
		}

		rc = mqtt_cycle(context, slice_ptr);
		if (rc < 0 || context->connected <= 0 || context->wakeup)
			break;
	}

	return rc;
}



/*
 * Prepare context for use.
 *
 * Allocates the initial receive buffer and initialises context->mutex and
 * the default fields. on_message/arg are stored and invoked later for each
 * received PUBLISH. If the caller left context->name empty, a generated
 * "hmqtt<N>" name is assigned.
 *
 * Returns 0, -EINVAL for a NULL context, or -ENOMEM if the buffer cannot
 * be allocated.
 */
int hmqtt_init(mqtt_context_t *context, void (*on_message)(mqtt_message_t *, void *), void *arg)
{
	static int name_seq;

	if (!context)
		return -EINVAL;

	/* start with a 4 KiB receive buffer; it is grown on demand */
	context->readbuf_size = 4096;
	context->readbuf = malloc(context->readbuf_size);
	if (!context->readbuf)
		return -ENOMEM;

	pthread_mutex_init(&context->mutex, NULL);

	context->socket = -1;
	context->on_message = on_message;
	context->msg_arg = arg;
	context->qos = 0;
	if (!context->name[0])
		sprintf(context->name, "hmqtt%d", name_seq++);
	strcpy(context->reason, "init");

#ifdef WITH_TLS
	SSL_library_init();
#endif

	return 0;
}


/*
 * Release what hmqtt_init() set up: the receive buffer and context->mutex.
 *
 * readbuf is reset to NULL and readbuf_size to 0, so calling this twice is
 * harmless (previously the second call freed the buffer again). A non-NULL
 * readbuf means hmqtt_init() got past its allocation and therefore also
 * initialised the mutex, so the mutex is destroyed only in that case.
 * The socket is not touched; call hmqtt_disconnect() first.
 *
 * Returns 0, or -EINVAL for a NULL context.
 */
int hmqtt_uninit(mqtt_context_t *context)
{
	if (context == NULL)
		return -EINVAL;

	if (context->readbuf) {
		free(context->readbuf);
		context->readbuf = NULL;
		context->readbuf_size = 0;
		pthread_mutex_destroy(&context->mutex);
	}

	return 0;
}
